"""
Enhanced version of POAGraph for text alignment
"""

import pickle
import textwrap
from typing import Dict, Optional

import numpy as np
from tqdm import tqdm

from src.text_poa_graph_utils import path_sim_llm
from src.global_edit_utils import clean_up_text

from .new_text_alignment import TextSeqGraphAlignment
from .poa_graph import Node, POAGraph


class TextNode(Node):
    """Graph node carrying a text fragment plus per-sequence bookkeeping."""

    def __init__(self, nodeID=-1, text=""):
        """Create the node and initialise the text-specific tracking fields."""
        super().__init__(nodeID, text)
        # sequence id -> phrasing that sequence used at this position
        self.variations = dict()
        # ids of the sequences that pass through this node
        self.sequences = list()
        self.influenceScore = 0
        self.num_tokens_used = 0

    def add_variation(self, text, sequence_id):
        """Record ``text`` as the phrasing used by ``sequence_id`` (overwrites any earlier entry)."""
        self.variations.update({sequence_id: text})

    @property
    def is_stable(self):
        """Whether this node's frequency reaches the owning graph's stability threshold.

        NOTE(review): ``frequency`` and ``graph`` are not assigned anywhere in
        this class; presumably they are provided by ``Node`` or set by the
        graph -- confirm before relying on this property.
        """
        observed = self.frequency
        required = self.graph.stability_threshold
        return observed >= required


class TextPOAGraph(POAGraph):
    def __init__(self, text=None, label=-1):
        """Initialise text-graph state, then delegate to the base graph constructor.

        The subclass attributes are assigned *before* ``super().__init__`` runs
        so that they already exist if the base constructor calls back into
        overridden methods (``addUnmatchedSeq`` writes to ``_seq_paths``).
        """
        # sequence label -> ordered node ids visited by that sequence
        self._seq_paths = {}
        # ids of nodes identified as consensus by merge_consensus_nodes
        self.consensus_node_ids = []
        # -1 means the START / END sentinel node has not been created yet
        self.start_id = -1
        self.end_id = -1
        # set by refine_graph when refinement cannot be completed
        self.failed = False
        # running token counters for similarity-judge calls
        self.num_input_tokens_used = 0
        self.num_output_tokens_used = 0
        super().__init__(text, label)

    def addNode(self, text):
        """Create a ``TextNode`` holding ``text``, register it, and return its id.

        The new node takes the next free node id; the graph is flagged as
        needing a topological re-sort.
        """
        node_id = self._nextnodeID
        self.nodedict[node_id] = TextNode(node_id, text)
        self.nodeidlist.append(node_id)
        self._nextnodeID = node_id + 1
        self._nnodes += 1
        self._needsSort = True
        return node_id

    def addUnmatchedSeq(self, text, label=-1, updateSequences=True):
        """Modified to handle text sequences"""
        if text is None:
            return

        # Handle both string and list input
        if isinstance(text, str):
            words = text.split()
        else:
            words = text

        firstID, lastID = None, None
        neededSort = self.needsSort

        path = []
        for word in words:
            nodeID = self.addNode(word)
            if firstID is None:
                firstID = nodeID
            if lastID is not None:
                self.addEdge(lastID, nodeID, label=label)
            lastID = nodeID
            path.append(nodeID)

        self._needsort = neededSort
        if updateSequences:
            self._seqs.append(words)
            self._labels.append(label)
            self._starts.append(firstID)
            self._seq_paths[label] = path

        return firstID, lastID

    def add_text(self, text, label=-1):
        """Main method to add new text to the alignment"""
        if len(self._seqs) == 0:
            # First sequence - just add it
            self.addUnmatchedSeq(text, label)
        else:
            # Align to existing graph
            alignment = TextSeqGraphAlignment(
                text, self, matchscore=2, mismatchscore=-1, gapscore=-2
            )
            self.incorporateSeqAlignment(alignment, text, label)

        # Update node frequencies
        self._update_frequencies()

    def removeNode(self, nodeID):
        """Override to handle text nodes"""
        node = self.nodedict[nodeID]
        if node is None:
            return

        # Remove all edges to this node
        out_edges = node.outEdges.copy()
        in_edges = node.inEdges.copy()

        for edge in out_edges:
            self.removeEdge(node.ID, edge)
        for edge in in_edges:
            self.removeEdge(edge, node.ID)

        # Remove from graph
        del self.nodedict[nodeID]
        self.nodeidlist.remove(nodeID)

        for path in self._seq_paths.values():
            if nodeID in path:
                path.remove(nodeID)

        self._nnodes -= 1
        self._needsSort = True

    def removeEdge(self, nodeID1, nodeID2):
        """Override to handle text nodes"""
        node1 = self.nodedict[nodeID1]
        node2 = self.nodedict[nodeID2]

        if node1 is None or node2 is None:
            return

        # Remove from graph
        del node1.outEdges[nodeID2]
        del node2.inEdges[nodeID1]

    def merge_consensus_nodes(self, verbose: bool = False):
        self.toposort()
        # reset consensus node ids
        self.consensus_node_ids = []
        nodes = list(self.nodeiterator()())
        consensus_segments = []
        i = 0
        while i < len(nodes):
            node = nodes[i]
            out_weight = sum(e.weight for e in node.outEdges.values())
            in_weight = sum(e.weight for e in node.inEdges.values())

            if out_weight in [0, self.num_sequences] and in_weight in [0, self.num_sequences]:
                consensus_segment = [(node.ID, node.text)]
                next_node = node
                while (i + 1) < len(nodes) and len(next_node.outEdges) == 1:
                    next_node = nodes[i + 1]
                    next_out_weight = sum(e.weight for e in next_node.outEdges.values())
                    next_in_weight = sum(e.weight for e in next_node.inEdges.values())

                    if (
                        next_out_weight != self.num_sequences
                        or next_in_weight != self.num_sequences
                    ):
                        break

                    consensus_segment.append((next_node.ID, next_node.text))
                    i += 1
                consensus_segments.append(consensus_segment)
            i += 1
        # merge consensus nodes into a single node
        for segment in consensus_segments:
            if len(segment) == 1:
                self.consensus_node_ids.append(segment[0][0])
                continue
            merged_text = " ".join([text for _, text in segment])
            first_node_id = segment[0][0]
            last_node_id = segment[-1][0]

            self.nodedict[last_node_id].text = merged_text
            self.consensus_node_ids.append(last_node_id)

            # attach all incoming edges to first node to last node
            for id, edge in self.nodedict[first_node_id].inEdges.items():
                weight = edge.weight
                for _ in range(weight):
                    self.addEdge(id, last_node_id, label=edge.labels)

            # delete all nodes except last node
            for node_id, _ in segment[:-1]:
                self.removeNode(node_id)

            

        if verbose:
            print(self.consensus_node_ids)

    """
    find all paths between start_node_id and end_node_id from original sequences
    return a list of dictionaries with the following keys:
    - path: list of node ids in the path (excluding start and including end)
    - text: text of the path (excluding start and end)
    - weight: minimal edge weight across all edges in the path
    - labels: intersection of all edge labels in the path
    """

    def find_paths_between(self, start_node_id: int, end_node_id: int):
        # find all paths between start_node_id and end_node_id from original sequences
        path_dicts = []

        # keep track of visited paths to avoid duplicates
        visited_paths = set()

        for _, path in self._seq_paths.items():
            start_index = path.index(start_node_id) if start_node_id in path else None
            end_index = path.index(end_node_id) if end_node_id in path else None

            # print(start_index, end_index)
            # print(path)

            if (
                start_index is not None
                and end_index is not None
                and end_index - start_index > 1
                and tuple(path[start_index + 1 : end_index + 1]) not in visited_paths
            ):
                # intersection of all edge labels in the path
                path_labels = set.intersection(
                    *[
                        set(self.nodedict[next_node_id].inEdges[node_id].labels)
                        for node_id, next_node_id in zip(
                            path[start_index:end_index], path[start_index + 1 : end_index + 1]
                        )
                    ]
                )
                path_weight = len(path_labels)
                path_dicts.append(
                    {
                        
                        "path": path[start_index + 1 : end_index + 1],
                        "body_text": " ".join(
                            [
                                self.nodedict[node_id].text
                                for node_id in path[start_index + 1 : end_index]
                            ]
                        ),
                        "begin_text": self.nodedict[path[start_index]].text,
                        "end_text": self.nodedict[path[end_index]].text,
                        "weight": path_weight,
                        "labels": path_labels,
                    }
                )
                visited_paths.add(tuple(path[start_index + 1 : end_index + 1]))

        return path_dicts

    def _follow_path(self, start_id):
        """Follow all possible paths from a node"""
        paths = []
        visited = set()

        def dfs(node_id, current_path):
            if node_id in visited:
                return
            visited.add(node_id)
            node = self.nodedict[node_id]

            if not node.outEdges:
                paths.append(current_path + [node_id])
                return

            for next_id in node.outEdges:
                dfs(next_id, current_path + [node_id])

        dfs(start_id, [])
        return paths

    def merge_paths_between(
        self,
        start_node_id: int,
        end_node_id: int,
        path_sim_type: str = "llm",
        verbose: bool = False,
        **kwargs,
    ):
        path_dicts = self.find_paths_between(start_node_id, end_node_id)

        if path_sim_type == "llm":
            api = kwargs.get("api", "openai")
            model = kwargs.get("model", "gpt-4o-mini")
            domain = kwargs.get("domain", None)
            similarity_judge_prompt = kwargs.get("similarity_judge_prompt", None)

            def path_sim_func(path1_text, path2_text):
                return path_sim_llm(
                    path1_text,
                    path2_text,
                    api=api,
                    model=model,
                    domain=domain,
                    custom_similarity_judge_prompt=similarity_judge_prompt,
                )

        elif path_sim_type == "cosine":
            pass
            # embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
            # threshold = kwargs.get("threshold", 0.9)
            # path_sim_func = path_sim_cosine(embedding_model, threshold)
        else:
            raise ValueError(f"Invalid path similarity type: {path_sim_type}")

        # merge paths based on semantic similarity
        path_equivalence_classes = {}
        class_count = 0

        for path_dict in path_dicts:
            if verbose:
                print(path_dict)
            found_class = False
            for _, eq_class in path_equivalence_classes.items():
                # check if path dict is already in an equivalence class
                path1_text = (
                    path_dict["begin_text"]
                    + " "
                    + path_dict["body_text"]
                    + " "
                    + path_dict["end_text"]
                )
                path2_text = (
                    eq_class[0]["begin_text"]
                    + " "
                    + eq_class[0]["body_text"]
                    + " "
                    + eq_class[0]["end_text"]
                )

                judgement, num_input_tokens, num_output_tokens = path_sim_func(
                    path1_text, path2_text
                )
                self.num_input_tokens_used += num_input_tokens
                self.num_output_tokens_used += num_output_tokens
                if judgement:
                    eq_class.append(path_dict)
                    found_class = True
                    break
            if not found_class:
                class_count += 1
                path_equivalence_classes[class_count] = [path_dict]

        nodes_to_remove = set()  # Track nodes to remove
        for _, eq_class in path_equivalence_classes.items():
            path_dict = eq_class[0]

            if verbose:
                print(eq_class)
            # add new node with merged text
            new_node_id = self.addNode(path_dict["body_text"])
            for sequence_id in path_dict["labels"]:
                self.nodedict[new_node_id].variations[sequence_id] = path_dict["body_text"]

            # collect nodes to remove from first path
            nodes_to_remove.update(path_dict["path"][:-1])

            # process data regarding weights and labels
            labels = list(path_dict["labels"])
            weight = path_dict["weight"]
            self.addEdge(start_node_id, new_node_id, label=labels, weight=weight)

            # Updated seq_paths for all labels to include new_node betwwen start_node and end_node
            for label in labels:
                index = self._seq_paths[label].index(start_node_id)
                if (
                    index + 1 < len(self._seq_paths[label])
                    and self._seq_paths[label][index + 1] != new_node_id
                ):
                    self._seq_paths[label].insert(index + 1, new_node_id)

            self.addEdge(new_node_id, end_node_id, label=labels, weight=weight)

            self.nodedict[new_node_id].sequences = labels
            # process additional paths
            for path_dict in eq_class[1:]:
                for sequence_id in path_dict["labels"]:
                    self.nodedict[new_node_id].variations[sequence_id] = path_dict["body_text"]
                nodes_to_remove.update(path_dict["path"][:-1])

                # copy incoming edges to new node
                labels = list(path_dict["labels"])
                weight = path_dict["weight"]
                self.addEdge(start_node_id, new_node_id, label=labels, weight=weight)

                # Updated seq_paths for all labels to include new_node betwwen start_node and end_node
                for label in labels:
                    index = self._seq_paths[label].index(start_node_id)
                    if (
                        index + 1 < len(self._seq_paths[label])
                        and self._seq_paths[label][index + 1] != new_node_id
                    ):
                        self._seq_paths[label].insert(index + 1, new_node_id)

                self.addEdge(new_node_id, end_node_id, label=labels, weight=weight)
                self.nodedict[new_node_id].sequences.extend(labels)

            self.nodedict[new_node_id].sequences = list(set(self.nodedict[new_node_id].sequences))

        # Remove all collected nodes after processing
        for node_id in nodes_to_remove:
            if node_id in self.nodedict:
                if verbose:
                    print(f"Removing node {node_id}")
                self.removeNode(node_id)

    def merge_divergent_paths(self, path_sim_type: str = "llm", verbose: bool = False, **kwargs):
        """Merge the divergent paths between every pair of neighbouring consensus nodes.

        Consensus nodes are computed first if none are recorded. On first use
        a ``START`` node is linked in front of every sequence path and an
        ``END`` node behind it, and both are added to ``consensus_node_ids``,
        so that leading and trailing divergences are bracketed as well.

        Args:
            path_sim_type: Forwarded to ``merge_paths_between``.
            verbose: Print progress information.
            **kwargs: Forwarded to ``merge_paths_between``.
        """
        if not self.consensus_node_ids:
            self.merge_consensus_nodes(verbose=verbose)

        self.toposort()

        start_missing = self.start_id == -1
        if start_missing:
            if verbose:
                print("Adding start node")
            self.start_id = self.addNode(text="START")
            # NOTE(review): addNode already advances _nextnodeID, so this
            # skips one id; kept as-is -- confirm whether it is intentional.
            self._nextnodeID += 1
            self.consensus_node_ids.insert(0, self.start_id)

            # Every sequence now begins at START.
            for label, seq_path in self._seq_paths.items():
                self.addEdge(self.start_id, seq_path[0], label=label, weight=1)
                seq_path.insert(0, self.start_id)

        end_missing = self.end_id == -1
        if end_missing:
            if verbose:
                print("Adding end node")
            self.end_id = self.addNode(text="END")
            # NOTE(review): same extra id skip as for START above.
            self._nextnodeID += 1
            self.consensus_node_ids = [*self.consensus_node_ids, self.end_id]

            # Every sequence now finishes at END.
            for label, seq_path in self._seq_paths.items():
                self.addEdge(seq_path[-1], self.end_id, label=label, weight=1)
                seq_path.append(self.end_id)

        # Walk neighbouring consensus nodes pairwise.
        for i in tqdm(range(len(self.consensus_node_ids) - 1)):
            left_id = self.consensus_node_ids[i]
            right_id = self.consensus_node_ids[i + 1]
            if verbose:
                print(left_id, right_id)
            self.merge_paths_between(
                left_id,
                right_id,
                path_sim_type=path_sim_type,
                verbose=verbose,
                **kwargs,
            )

    def get_variable_node_ids(self):
        return [
            node.ID for node in self.nodedict.values() if node.ID not in self.consensus_node_ids
        ]

    def compress_paths_between(self, start_node_id: int, end_node_id: int):
        pass

    def compress_graph(self):
        pass

    def update_influence_scores(self, outcome: Dict[int, float], discount_factor: float = 0.2):
        self.toposort()
        direct_scores = []
        for node in self.nodedict.values():
            next_out_weight = sum(e.weight for e in node.outEdges.values())
            next_in_weight = sum(e.weight for e in node.inEdges.values())
            if next_out_weight == self.num_sequences and next_in_weight == self.num_sequences:
                out_list = []
                for edge in node.outEdges.values():
                    for _ in range(len(set(edge.labels))):
                        out_list.append(np.mean([outcome[label] for label in set(edge.labels)]))
                direct_scores.append((node.ID, np.var(out_list)))

        scores = direct_scores.copy()

        # Start from the end and propagate influence backward
        for i in range(len(scores) - 2, -1, -1):
            # Current node gets its direct influence plus discounted influence of next node
            current_direct = scores[i][1]
            next_total = scores[i + 1][1]
            scores[i] = (scores[i][0], current_direct + discount_factor * next_total)

        scores.sort(key=lambda x: x[1], reverse=True)
        return scores

    def jsOutput(
        self,
        verbose: bool = False,
        annotate_consensus: bool = True,
        color_annotations: Dict[int, str] = None,
    ):
        """returns a list of strings containing a a description of the graph for viz.js, http://visjs.org"""

        # get the consensus sequence, which we'll use as the "spine" of the
        # graph
        pathdict = {}
        if annotate_consensus:
            path, __, __ = self.consensus()
        lines = ["var nodes = ["]

        ni = self.nodeiterator()
        count = 0
        for node in ni():
            title_text = ""
            if node.sequences:
                title_text += f"Sequences: {node.sequences}"
            if node.variations:
                title_text += ";;;".join(
                    [f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()]
                )
                title_text = title_text.replace('"', "'")
            line = (
                "    {id:"
                + str(node.ID)
                + ', label: "'
                + str(node.ID)
                + ": "
                + node.text.replace('"', "'")
                + '", title: '
                + '"'
                + title_text
                + '",'
            )
            if color_annotations and node.ID in color_annotations:
                line += f" color: '{color_annotations[node.ID]}', "
            if node.ID in pathdict and count % 5 == 0 and annotate_consensus:
                line += (
                    ", x: "
                    + str(pathdict[node.ID])
                    + ", y: 0 , fixed: { x:true, y:false},"
                    + "color: '#7BE141', is_consensus:true},"
                )
            else:
                line += "},"
            lines.append(line)

        lines[-1] = lines[-1][:-1]
        lines.append("];")

        lines.append(" ")

        lines.append("var edges = [ ")
        ni = self.nodeiterator()
        for node in ni():
            nodeID = str(node.ID)
            for edge in node.outEdges:
                target = str(edge)
                weight = str(node.outEdges[edge].weight + 1.5)
                lines.append(
                    "    {from: "
                    + nodeID
                    + ", to: "
                    + target
                    + ", value: "
                    + weight
                    + ", color: '#4b72b0', arrows: 'to'},"
                )
            if verbose:
                for alignededge in node.alignedTo:
                    # These edges indicate alignment to different bases, and are
                    # undirected; thus make sure we only plot them once:
                    if node.ID > alignededge:
                        continue
                    target = str(alignededge)
                    lines.append(
                        "    {from: "
                        + nodeID
                        + ", to: "
                        + target
                        + ', value: 1, style: "dash-line", color: "red"},'
                    )

        lines[-1] = lines[-1][:-1]
        lines.append("];")
        return lines

    def htmlOutput(
        self,
        outfile,
        verbose: bool = False,
        annotate_consensus: bool = True,
        color_annotations: Dict[int, str] = None,
    ):
        header = """
                  <!doctype html>
                  <html>
                  <head>
                    <title>POA Graph Alignment</title>

                    <script type="text/javascript" src="https://unpkg.com/vis-network@9.0.4/standalone/umd/vis-network.min.js"></script>
                  </head>

                  <body>

                  <div id="loadingProgress">0%</div>

                  <div id="mynetwork"></div>

                  <script type="text/javascript">
                    // create a network
                  """
        outfile.write(textwrap.dedent(header[1:]))
        lines = self.jsOutput(
            verbose=verbose,
            annotate_consensus=annotate_consensus,
            color_annotations=color_annotations,
        )
        for line in lines:
            outfile.write(line + "\n")
        footer = """
                  var container = document.getElementById('mynetwork');
                  var data= {
                    nodes: nodes,
                    edges: edges,
                  };
                  var options = {
                    width: '100%',
                    height: '800px',
                    physics: {
                        enabled: false,
                        stabilization: {
                            updateInterval: 10,
                        },
                    },
                    edges: {
                        color: {
                            inherit: false
                        }
                    },
                    layout: {
                        hierarchical: {
                            direction: "UD",
                            sortMethod: "directed",
                            shakeTowards: "roots",
                            levelSeparation: 150, // Adjust as needed
                            nodeSpacing: 800, // Adjust as needed
                            treeSpacing: 200, // Adjust as needed
                            parentCentralization: true,
                        }
                    }
                  };
                  var network = new vis.Network(container, data, options);
                  
                  network.on('beforeDrawing', function(ctx) {
                    nodes.forEach(function(node) {
                        if (node.isConsensus) {
                            // Set the level of spine nodes to the bottom
                            network.body.data.nodes.update({
                                id: node.id,
                                level: 0 // Set level to 0 for spine nodes
                            });
                        }
                    });
                });

                  network.on("stabilizationProgress", function (params) {
                    document.getElementById("loadingProgress").innerText = Math.round(params.iterations / params.total * 100) + "%";
                  });
                  network.once("stabilizationIterationsDone", function () {
                      document.getElementById("loadingProgress").innerText = "100%";
                      setTimeout(function () {
                        document.getElementById("loadingProgress").style.display = "none";
                      }, 500);
                  });

                
                </script>

                </body>
                </html>
                """
        outfile.write(textwrap.dedent(footer))

    
    def multi_consensus_response(self, abstention_threshold: Optional[float] = None, filter: bool = True):
        self.toposort()
        nodesInReverse = self.nodeidlist[::-1]
        maxnodeID = self.end_id
        nextInPath = [-1] * maxnodeID
        scores = np.zeros(len(self.nodeidlist))

        id_to_index = {node_id: index for index, node_id in enumerate(self.nodeidlist)}
        index_to_id = {index: node_id for index, node_id in enumerate(self.nodeidlist)}

        for nodeID in nodesInReverse:
            bestWeightScoreEdges = [(-1, -1, None)]
            for neighbourID in self.nodedict[nodeID].outEdges:
                # print(f"nodeID: {nodeID}, neighbourID: {neighbourID}")
                e = self.nodedict[nodeID].outEdges[neighbourID]
                weightScoreEdge = (e.weight, scores[id_to_index[neighbourID]], neighbourID)
                
                
                if weightScoreEdge > bestWeightScoreEdges[0]:
                    bestWeightScoreEdges = [weightScoreEdge]
                elif weightScoreEdge == bestWeightScoreEdges[0] and filter:
                    bestWeightScoreEdges.append(weightScoreEdge)

            
            scores[id_to_index[nodeID]] = sum(bestWeightScoreEdges[0][0:2])
            if bestWeightScoreEdges[0][2] is not None:
                nextInPath[id_to_index[nodeID]] = id_to_index[bestWeightScoreEdges[0][2]]
            else:
                nextInPath[id_to_index[nodeID]] = None

        pos = np.argmax(scores)
        path = []
        text = []
        labels = []

        while pos is not None and pos > -1:
            if abstention_threshold is not None and self.nodedict[index_to_id[pos]].variations:
                if (
                    len(self.nodedict[index_to_id[pos]].labels) / self.num_sequences
                    >= abstention_threshold
                ):
                    path.append(index_to_id[pos])
                    labels.append(self.nodedict[index_to_id[pos]].labels)
                    text.append(self.nodedict[index_to_id[pos]].text)
            else:
                path.append(index_to_id[pos])
                labels.append(self.nodedict[index_to_id[pos]].labels)
                text.append(self.nodedict[index_to_id[pos]].text)
            pos = nextInPath[pos]

        # ignore END node
        path = path[:-1]
        # ignore END node
        text = text[:-1]
        # ignore START in text
        text[0] = text[0].replace("START", "")
        labels = labels[:-1]

        return " ".join(text)
    

    def consensus_response(
        self, selection_threshold: Optional[float] = 0.5, api: str = "openai" , model: str = "gpt-4o-mini", task: str = "bio", **kwargs
    ) -> str:
        """Build a response from the consensus nodes plus well-supported branches.

        Every consensus node except the START / END sentinels is selected, in
        order. After each one, its non-consensus successors are selected too
        when the fraction of sequences in their ``labels`` is at least
        ``selection_threshold``. The joined text is passed through
        ``clean_up_text`` and the cleaned result is returned.

        Args:
            selection_threshold: Minimum sequence fraction for a
                non-consensus successor to be included.
            api: Forwarded to ``clean_up_text``.
            model: Forwarded to ``clean_up_text``.
            task: Forwarded to ``clean_up_text``.
            **kwargs: Forwarded to ``clean_up_text``.
        """
        self.toposort()

        consensus_ids = self.consensus_node_ids
        print(consensus_ids)

        sentinels = (self.start_id, self.end_id)
        chosen_ids = []

        for node_id in consensus_ids:
            if node_id in sentinels:
                continue

            chosen_ids.append(node_id)

            # Pull in divergent successors that enough sequences agree on.
            chosen_ids.extend(
                neighbor_id
                for neighbor_id in self.nodedict[node_id].outEdges
                if neighbor_id not in consensus_ids
                and len(self.nodedict[neighbor_id].labels) / self.num_sequences
                >= selection_threshold
            )

        text = " ".join(self.nodedict[chosen_id].text for chosen_id in chosen_ids)
        print(text)
        return clean_up_text(text, task=task, api=api, model=model, **kwargs)

    def save_to_pickle(self, filename):
        with open(filename, "wb+") as f:
            pickle.dump(self, f)

    def refine_graph(
        self,
        verbose: bool = False,
        save_intermediate_file: str = None,
        final_merge: bool = True,
        **kwargs,
    ):
        self.merge_consensus_nodes(verbose=verbose)

        if save_intermediate_file:
            with open(save_intermediate_file, "w+") as f:
                self.htmlOutput(f, annotate_consensus=False)

        if not self.consensus_node_ids:
            self.failed = True
            return

        else:
            self.merge_divergent_paths(verbose=verbose, **kwargs)

            if final_merge:
                try:
                    self.merge_consensus_nodes(verbose=verbose)
                except Exception as e:
                    print(e)
                    self.failed = True
